그래프를 그리기 위해서 matplotlib
을 임포트 합니다. %matplotlib inline
은 새로운 창을 띄우지 않고 주피터 노트북 안에 이미지를 삽입하여 줍니다.
In [ ]:
import matplotlib.pyplot as plt
%matplotlib inline
텐서플로우를 tf
란 이름으로 임포트 하세요.
In [ ]:
tf.Session()
을 사용하여 세션 객체를 하나 만드세요.
sess = tf.Session()
In [ ]:
임의의 샘플 데이터를 만들려고 합니다. 평균 0, 표준 편차 0.55 인 샘플 데이터 1000개를 만듭니다.
x_raw = tf.random_normal([...], mean=.., stddev=..)
x = sess.run(x_raw)
In [ ]:
x_raw = ...
x = ...
위에서 x 축의 값을 만들었으니 이에 상응하는 y 축의 값을 만들려고 합니다. y 값은 0.1*x+0.3 을 만족하되 실제 데이터처럼 보이게 하려고 난수를 조금 섞어서 만듭니다. 여기서는 평균 0, 표준 편차 0.03 인 정규 분포 난수를 만듭니다.
y_raw = 0.1 * x + 0.3 + tf.random_normal([...], mean=.., stddev=..)
y = sess.run(y_raw)
In [ ]:
y_raw = ...
y = ...
만든 샘플 데이터를 산점도로 나타내보겠습니다. plot 명령에 x, y 축의 값을 전달하고 산점도 표시는 원 모양 'o'
으로 하고 테두리 선을 검은색으로 그리도록 하겠습니다.
plt.plot(x, y, 'o', markeredgecolor='k')
In [ ]:
선형 회귀에서 사용할 두개의 변수 W 와 b 를 만들고 직선 방정식을 구성합니다.
W = tf.Variable(tf.zeros([.]))
b = ...(tf.zeros([.]))
y_hat = W * x + b
In [ ]:
W = ...
b = ...
y_hat = ...
회귀에서의 손실함수는 평균 제곱 오차(mean squared error)입니다. 텐서플로우에서 사용하는 오차 함수 tf.loss.mean_squared_error()
를 사용하여 손실 함수를 위한 노드를 만듭니다. 이 함수에 전달할 매개변수는 정답 y와 예측한 값 y_hat 입니다.
loss = tf.losses.mean_squared_error(y, y_hat)
경사하강법은 텐서플로우 tf.train.GradientDescentOptimizer()
에 구현되어 있습니다. 경사하강법 학습속도를 0.5로 주고 최적화 연산을 만듭니다.
optimizer = tf.train.GradientDescentOptimizer(0.5)
optimizer.minimize()
함수에 손실 함수 객체를 넘겨주어 학습할 최종 객체를 생성합니다.
train = optimizer.minimize(loss)
In [ ]:
loss = ...
optimizer = ...
train = ...
계산 그래프에 필요한 변수를 초기화합니다.
In [ ]:
init = ...
sess.run(init)
sess.run()
메소드를 이용해 필요한 연산을 수행할 수 있습니다. 반드시 수행할 것은 train
이고 화면 출력을 위해 W, b, loss 를 계산해서 값을 반환 받겠습니다.
_, w_, b_, c = sess.run([train, W, b, loss])
반환 받은 c 는 costs 리스트에 추가하여 나중에 손실함수 그래프를 그리겠습니다. w, b 를 이용해 위 산점도에 직선이 어떻게 맞춰지는지 그림으로 표현합니다.
plt.plot(x, w_ * x + b_)
In [ ]:
costs = []
for step in range(10):
_, w_, b_, c = ...
costs.append(c)
print(step, w_, b_, c)
# 산포도 그리기
plt.plot(x, y, 'o', markeredgecolor='k')
# 직선 그리기
plt.plot(...)
# x, y 축 레이블링을 하고 각 축의 최대, 최소값 범위를 지정합니다.
plt.xlabel('x')
plt.xlim(-2,2)
plt.ylim(0.1,0.6)
plt.ylabel('y')
plt.show()
손실 함수 값이 들어 있는 costs
리스트를 이용해 그래프를 출력합니다.
plt.plot(costs)
In [ ]: